{
 "cells": [
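  {
   "cell_type": "markdown",
   "id": "e8896660-62a3-4bbc-b80b-40ffddf38cba",
   "metadata": {},
   "source": [
    "<center><h1>Endangered Large Animal Species Recognition Challenge</h1></center>\n",
    "\n",
    "> Registration link: https://challenge.xfyun.cn/topic/info?type=species-recognition&option=ssgy&ch=dw24_AtTCK9\n",
    "\n",
    "\n",
    "# I. Background\n",
    "\n",
    "Global biodiversity continues to decline, and the sharp drop in the numbers of endangered large animals in particular has become an urgent issue in global conservation. Endangered animals such as elephants, rhinoceroses, and big cats face many threats, including habitat loss and illegal hunting, which endanger biodiversity and disrupt ecological balance. Accurate, timely monitoring and protection measures are therefore essential. Traditional monitoring methods, however, are time-consuming and inefficient, and cannot keep up with this urgent need. An automated recognition system for endangered large animals, built with artificial intelligence and able to identify specific animals from image data, would greatly improve the efficiency and accuracy of conservation work. This challenge therefore has real research value: it aims to advance the use of AI in wildlife conservation, raise public awareness, and support ecological protection worldwide.\n",
    "\n",
    "# II. Task\n",
    "\n",
    "This challenge applies artificial intelligence to make the monitoring and protection of endangered large animals more efficient. The organizers provide animal images collected from multiple sources, including camera traps and drones, together with the geographic location of the reserve where each image was taken. Participants must use this data to build a model that accurately identifies the species of endangered large animals. The key difficulty is recognizing similar-looking species under varying lighting, viewing angles, and backgrounds.\n",
    "\n",
    "# III. Evaluation Rules\n",
    "\n",
    "## 1. Data Description\n",
    "\n",
    "The challenge provides participants with a large image database covering 9 endangered animal species. The images were collected with a variety of methods, such as camera traps and drones, at different times, under different lighting conditions, and against varied natural backgrounds, in order to simulate the range of situations met in real field monitoring.\n",
    "\n",
    "(1) Dataset splits:\n",
    "\n",
    "The training set is used to train models and provides animal images with their class labels. The test set is used to evaluate model performance and contains images only; participants must predict the species shown in each image.\n",
    "\n",
    "(2) Classes in the recognition dataset:\n",
    "\n",
    "| Label       |\n",
    "| ----------- |\n",
    "| Badger      |\n",
    "| BlackBear   |\n",
    "| Cheetah     |\n",
    "| Hare        |\n",
    "| LeopardCat  |\n",
    "| MuskDeer    |\n",
    "| AmurLeopard |\n",
    "| Tiger       |\n",
    "| RedFox      |\n",
    "\n",
    "## 2. Evaluation Metric\n",
    "\n",
    "Submitted result files are scored with the macro-F1 score.\n",
    "\n",
    "## 3. Evaluation and Ranking\n",
    "\n",
    "(1) Data is available for download in both the preliminary and semi-final rounds. Participants develop and debug their algorithms locally and submit results on the competition page.\n",
    "\n",
    "(2) The competition uses A and B leaderboards. A-leaderboard scores are visible to teams during the competition; the final ranking uses each team's best B-leaderboard score.\n",
    "\n",
    "# IV. Submission Requirements\n",
    "\n",
    "1. File format: submit a CSV file\n",
    "\n",
    "2. File size: no limit\n",
    "\n",
    "3. Submission limit: at most 3 per team per day\n",
    "\n",
    "4. File details: UTF-8 encoding, with a header in the first row; see the sample for the submission format\n",
    "\n",
    "5. No other files need to be uploaded\n",
    "\n",
    "# V. Schedule\n",
    "\n",
    "This track is run as a single round.\n",
    "\n",
    "## [Competition Period]\n",
    "\n",
    "July 17 to August 29\n",
    "\n",
    "1. At 10:00 on July 17 the training and development sets are released and the leaderboard opens; the B-leaderboard test set is released on August 15\n",
    "\n",
    "2. The submission deadline is 17:00 on August 29; rankings are announced at 10:00 on September 9\n",
    "\n",
    "## [On-site Defense]\n",
    "\n",
    "1. The final top three teams will be invited to the iFLYTEK Global 1024 Developer Festival to defend their work on site\n",
    "\n",
    "2. The defense consists of a 10-minute presentation followed by a 5-minute Q&A\n",
    "\n",
    "3. The final score combines the submission score and the defense score (submission score 70%, on-site defense 30%)\n",
    "\n",
    "# VI. Prizes\n",
    "\n",
    "This track awards a first, second, and third prize, three in total, as detailed below:\n",
    "\n",
    "## [Awards]\n",
    "\n",
    "1.  Award certificates for the TOP3 teams\n",
    "2.  Track prize money: 5,000 yuan for first place, 3,000 yuan for second place, 2,000 yuan for third place\n",
    "\n",
    "## [Resource Incentives]\n",
    "\n",
    "1.  A personal resource pack of premium AI capabilities on the iFLYTEK Open Platform\n",
    "2.  iFLYTEK AI full-chain startup support resources\n",
    "3.  iFLYTEK fast-track internship/employment channel\n",
    "\n",
    "Notes:\n",
    "\n",
    "1.  Participants are encouraged to send articles about their competition experience, technical write-ups, or experience using related technologies or products to the organizing committee (AICompetition@iflytek.com) for a chance to receive competition merchandise;\n",
    "2.  iFLYTEK reserves the right to interpret the competition rules and the awarding of prize money. All prize amounts are pre-tax; the organizer will withhold and pay individual income tax on the winners' behalf."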
   ]
  },
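  {
   "cell_type": "markdown",
   "id": "sketch-macro-f1",
   "metadata": {},
   "source": [
    "The rules above score submissions with macro-F1. The following is a minimal sketch of what that metric computes, assuming the usual definition (the unweighted mean of the per-class F1 scores). It also assumes that a class with no true and no predicted samples scores 0, which the official scorer may handle differently. The function name `macro_f1` and the toy labels are illustrative and not part of the competition kit.\n",
    "\n",
    "```python\n",
    "def macro_f1(y_true, y_pred, labels):\n",
    "    # Unweighted mean of per-class F1 scores\n",
    "    scores = []\n",
    "    for label in labels:\n",
    "        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))\n",
    "        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))\n",
    "        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))\n",
    "        denom = 2 * tp + fp + fn\n",
    "        # Assumption: a class that never appears contributes 0\n",
    "        scores.append(2 * tp / denom if denom else 0.0)\n",
    "    return sum(scores) / len(scores)\n",
    "\n",
    "\n",
    "# Toy labels, not competition data\n",
    "y_true = ['Tiger', 'Tiger', 'Hare', 'RedFox']\n",
    "y_pred = ['Tiger', 'Hare', 'Hare', 'RedFox']\n",
    "score = macro_f1(y_true, y_pred, ['Tiger', 'Hare', 'RedFox'])\n",
    "print(round(score, 4))\n",
    "```\n",
    "\n",
    "Because every class carries the same weight, a rarely seen species affects the score as much as a common one, so per-class errors on the validation split are worth checking alongside plain accuracy."
   ]
  },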
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "69cfe08a-db6c-49f2-adc2-d52cb1387e99",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.aliyun.com/pypi/simple\n",
      "Requirement already satisfied: timm in ./miniconda3/lib/python3.8/site-packages (1.0.7)\n",
      "Collecting pandas\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/f8/7f/5b047effafbdd34e52c9e2d7e44f729a0655efafb22198c45cf692cdc157/pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n",
      "\u001b[K     |████████████████████████████████| 12.4 MB 561 kB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: pyyaml in ./miniconda3/lib/python3.8/site-packages (from timm) (6.0.1)\n",
      "Requirement already satisfied: huggingface_hub in ./miniconda3/lib/python3.8/site-packages (from timm) (0.23.5)\n",
      "Requirement already satisfied: torch in ./miniconda3/lib/python3.8/site-packages (from timm) (1.10.0+cu113)\n",
      "Requirement already satisfied: safetensors in ./miniconda3/lib/python3.8/site-packages (from timm) (0.4.3)\n",
      "Requirement already satisfied: torchvision in ./miniconda3/lib/python3.8/site-packages (from timm) (0.11.1+cu113)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in ./miniconda3/lib/python3.8/site-packages (from pandas) (2.8.2)\n",
      "Collecting tzdata>=2022.1\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl (345 kB)\n",
      "\u001b[K     |████████████████████████████████| 345 kB 550 kB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.20.3 in ./miniconda3/lib/python3.8/site-packages (from pandas) (1.21.4)\n",
      "Requirement already satisfied: pytz>=2020.1 in ./miniconda3/lib/python3.8/site-packages (from pandas) (2021.3)\n",
      "Requirement already satisfied: six>=1.5 in ./miniconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
      "Requirement already satisfied: requests in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (2.25.1)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (4.0.0)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (2024.6.1)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (4.61.2)\n",
      "Requirement already satisfied: packaging>=20.9 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (21.3)\n",
      "Requirement already satisfied: filelock in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (3.15.4)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in ./miniconda3/lib/python3.8/site-packages (from packaging>=20.9->huggingface_hub->timm) (3.0.6)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (1.26.6)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (2021.5.30)\n",
      "Requirement already satisfied: idna<3,>=2.5 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (2.10)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (4.0.0)\n",
      "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in ./miniconda3/lib/python3.8/site-packages (from torchvision->timm) (8.4.0)\n",
      "Installing collected packages: tzdata, pandas\n",
      "Successfully installed pandas-2.0.3 tzdata-2024.1\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install timm pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c64d3134-4dec-44ff-9f5e-55b033c48edb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.manual_seed(0)\n",
    "# cuDNN autotuning is faster for fixed-size inputs but is not bit-for-bit reproducible\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import time\n",
    "import glob\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "# tqdm.tqdm_notebook is deprecated; tqdm.notebook.tqdm is the supported import\n",
    "from tqdm.notebook import tqdm as tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "01acdbde-f012-45db-ae94-beff97baa8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_list = ['Badger', 'BlackBear', 'Cheetah', 'Hare', 'LeopardCat', 'MuskDeer', 'AmurLeopard', 'Tiger', 'RedFox']\n",
    "\n",
    "# Sort, then shuffle with a fixed seed, so the train/validation split is the same on every run\n",
    "train_path = sorted(glob.glob('./train/*/*.jpg'))\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(train_path)\n",
    "# The class name is the parent directory of each image\n",
    "train_label = [label_list.index(x.split('/')[-2]) for x in train_path]\n",
    "\n",
    "test_path = glob.glob('./testA/*.jpg')"
   ]
  },
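  {
   "cell_type": "markdown",
   "id": "sketch-label-from-path",
   "metadata": {},
   "source": [
    "The cell above derives each integer label from the name of the image's parent directory. A minimal sketch of that mapping follows, assuming the `./train/<ClassName>/<file>.jpg` layout implied by the glob pattern. The helper `label_from_path` and the file name `0001.jpg` are hypothetical and serve only to illustrate the lookup.\n",
    "\n",
    "```python\n",
    "from pathlib import PurePosixPath\n",
    "\n",
    "label_list = ['Badger', 'BlackBear', 'Cheetah', 'Hare', 'LeopardCat',\n",
    "              'MuskDeer', 'AmurLeopard', 'Tiger', 'RedFox']\n",
    "\n",
    "\n",
    "def label_from_path(path):\n",
    "    # The parent directory name is the class name\n",
    "    class_name = PurePosixPath(path).parent.name\n",
    "    # list.index raises ValueError for a directory that is not a known class\n",
    "    return label_list.index(class_name)\n",
    "\n",
    "\n",
    "# Hypothetical file name, used only to show the mapping\n",
    "example = label_from_path('./train/Tiger/0001.jpg')\n",
    "print(example, label_list[example])\n",
    "```\n",
    "\n",
    "The position of a class in `label_list` is its integer label, so the same list has to be used to turn predicted indices back into class names."
   ]
  },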
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ab48904c-3f33-4299-a6ab-c44ee9a40670",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageMeter(object):\n",
    "    \"\"\"Computes and stores the average and current value\"\"\"\n",
    "    def __init__(self, name, fmt=':f'):\n",
    "        self.name = name\n",
    "        self.fmt = fmt\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.val = 0\n",
    "        self.avg = 0\n",
    "        self.sum = 0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        self.avg = self.sum / self.count\n",
    "\n",
    "    def __str__(self):\n",
    "        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n",
    "        return fmtstr.format(**self.__dict__)\n",
    "\n",
    "\n",
    "class ProgressMeter(object):\n",
    "    \"\"\"Prints the batch index followed by every tracked meter\"\"\"\n",
    "    def __init__(self, num_batches, *meters):\n",
    "        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n",
    "        self.meters = meters\n",
    "        self.prefix = \"\"\n",
    "\n",
    "    def display(self, batch):\n",
    "        entries = [self.prefix + self.batch_fmtstr.format(batch)]\n",
    "        entries += [str(meter) for meter in self.meters]\n",
    "        print('\\t'.join(entries))\n",
    "\n",
    "    def _get_batch_fmtstr(self, num_batches):\n",
    "        num_digits = len(str(num_batches // 1))\n",
    "        fmt = '{:' + str(num_digits) + 'd}'\n",
    "        return '[' + fmt + '/' + fmt.format(num_batches) + ']'\n",
    "\n",
    "\n",
    "def validate(val_loader, model, criterion):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "\n",
    "            # measure accuracy (in percent) and record loss\n",
    "            acc = (output.argmax(1).view(-1) == target.view(-1)).float().mean() * 100\n",
    "            losses.update(loss.item(), input.size(0))\n",
    "            top1.update(acc, input.size(0))\n",
    "            # measure elapsed time\n",
    "            batch_time.update(time.time() - end)\n",
    "            end = time.time()\n",
    "\n",
    "        # TODO: this should also be done with the ProgressMeter\n",
    "        print(' * Acc@1 {top1.avg:.3f}'\n",
    "              .format(top1=top1))\n",
    "        # top1.avg is a 0-dim tensor; callers read it with .item()\n",
    "        return top1\n",
    "\n",
    "\n",
    "def predict(test_loader, model, tta=10):\n",
    "    \"\"\"Returns softmax probabilities summed over `tta` passes of the loader\"\"\"\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "\n",
    "    test_pred_tta = None\n",
    "    for _ in range(tta):\n",
    "        test_pred = []\n",
    "        with torch.no_grad():\n",
    "            # the loader also yields a label, which is not needed for prediction\n",
    "            for input, _target in tqdm_notebook(test_loader, total=len(test_loader)):\n",
    "                input = input.cuda()\n",
    "\n",
    "                # compute class probabilities\n",
    "                output = model(input)\n",
    "                output = F.softmax(output, dim=1)\n",
    "                output = output.cpu().numpy()\n",
    "\n",
    "                test_pred.append(output)\n",
    "        test_pred = np.vstack(test_pred)\n",
    "\n",
    "        # accumulate probabilities across passes\n",
    "        if test_pred_tta is None:\n",
    "            test_pred_tta = test_pred\n",
    "        else:\n",
    "            test_pred_tta += test_pred\n",
    "\n",
    "    return test_pred_tta\n",
    "\n",
    "\n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to train mode\n",
    "    model.train()\n",
    "\n",
    "    end = time.time()\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # measure accuracy (in percent) and record loss\n",
    "        losses.update(loss.item(), input.size(0))\n",
    "\n",
    "        acc = (output.argmax(1).view(-1) == target.view(-1)).float().mean() * 100\n",
    "        top1.update(acc, input.size(0))\n",
    "\n",
    "        # compute gradients and take an optimizer step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # measure elapsed time\n",
    "        batch_time.update(time.time() - end)\n",
    "        end = time.time()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            progress.display(i)"
   ]
  },
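  {
   "cell_type": "markdown",
   "id": "sketch-tta-accumulate",
   "metadata": {},
   "source": [
    "The `predict` function above adds up the softmax outputs of several passes over the test loader (test-time augmentation) and leaves the final class choice to the caller. A minimal sketch of that accumulation step follows, written with plain Python lists instead of NumPy arrays so it runs anywhere. The helpers `accumulate_tta` and `argmax` and the probabilities are made up for illustration.\n",
    "\n",
    "```python\n",
    "def accumulate_tta(passes):\n",
    "    # passes: one entry per pass, each a list of per-image probability lists\n",
    "    total = [list(row) for row in passes[0]]\n",
    "    for one_pass in passes[1:]:\n",
    "        for row_total, row in zip(total, one_pass):\n",
    "            for k, p in enumerate(row):\n",
    "                row_total[k] += p\n",
    "    return total\n",
    "\n",
    "\n",
    "def argmax(row):\n",
    "    # Index of the largest value in the row\n",
    "    return max(range(len(row)), key=row.__getitem__)\n",
    "\n",
    "\n",
    "# Made-up probabilities: 2 passes, 2 images, 3 classes\n",
    "passes = [\n",
    "    [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]],\n",
    "    [[0.2, 0.7, 0.1], [0.1, 0.3, 0.6]],\n",
    "]\n",
    "summed = accumulate_tta(passes)\n",
    "predicted = [argmax(row) for row in summed]\n",
    "print(predicted)\n",
    "```\n",
    "\n",
    "Summing and averaging give the same argmax, so dividing by the number of passes is only needed when the probabilities themselves are reported. The passes differ only if the test transform contains random augmentations."
   ]
  },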
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dde84a60-5939-4b69-b0a9-bb43773186b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class XFDataset(Dataset):\n",
    "    \"\"\"Loads an image from disk and returns it with its integer label\"\"\"\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # convert to RGB so grayscale and RGBA files all yield 3 channels\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        return img, torch.from_numpy(np.array(self.img_label[index]))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aad4267f-b007-48a4-8e06-cf89df9b9f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "import timm\n",
    "model = timm.create_model('efficientnet_b1', pretrained=False, num_classes=9)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "abc0c12c-5cdd-42fd-bfd6-9027431e46e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0/96]\tTime  2.672 ( 2.672)\tLoss 4.4935e+00 (4.4935e+00)\tAcc@1  10.00 ( 10.00)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1641/2676678360.py:51: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3fac89d56fcf4143ae10b4d22c1aaf45",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 79.000\n",
      "Epoch:  1\n",
      "[ 0/96]\tTime  2.400 ( 2.400)\tLoss 6.7960e-01 (6.7960e-01)\tAcc@1  80.00 ( 80.00)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd910c3c512142859b1de658f8d12fc1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 81.800\n",
      "Epoch:  2\n",
      "[ 0/96]\tTime  2.611 ( 2.611)\tLoss 7.5036e-01 (7.5036e-01)\tAcc@1  75.71 ( 75.71)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "044ab8b0435f454dbb91c20e79578a5a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 76.200\n",
      "Epoch:  3\n",
      "[ 0/96]\tTime  2.459 ( 2.459)\tLoss 5.4423e-01 (5.4423e-01)\tAcc@1  84.29 ( 84.29)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a76adb6d5b3403081abb8931cbbb12c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 88.800\n",
      "Epoch:  4\n",
      "[ 0/96]\tTime  2.294 ( 2.294)\tLoss 1.5976e-01 (1.5976e-01)\tAcc@1  94.29 ( 94.29)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4482a57869b643e09f569c2d2793de03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 89.800\n",
      "Epoch:  5\n",
      "[ 0/96]\tTime  2.563 ( 2.563)\tLoss 3.0123e-01 (3.0123e-01)\tAcc@1  92.86 ( 92.86)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bbd29fb2592047a8b6535229c2ac14a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 89.800\n",
      "Epoch:  6\n",
      "[ 0/96]\tTime  2.137 ( 2.137)\tLoss 2.1807e-01 (2.1807e-01)\tAcc@1  91.43 ( 91.43)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "430ffe3b037a49738905bbabdfb7832a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 88.600\n",
      "Epoch:  7\n",
      "[ 0/96]\tTime  2.508 ( 2.508)\tLoss 3.7349e-01 (3.7349e-01)\tAcc@1  88.57 ( 88.57)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1641/1845096411.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     30\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Epoch: '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m     \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m     \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1641/2676678360.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(train_loader, model, criterion, optimizer, epoch)\u001b[0m\n\u001b[1;32m    108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m     \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    111\u001b[0m         \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1185\u001b[0m             \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_shutdown\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tasks_outstanding\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1186\u001b[0;31m             \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1187\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tasks_outstanding\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1188\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1140\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1141\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_thread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_alive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1142\u001b[0;31m                 \u001b[0msuccess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1143\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0msuccess\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1144\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    988\u001b[0m         \u001b[0;31m#   (bool: whether successfully get data, any: data if successful else None)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    989\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_queue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    991\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    992\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/queue.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m    177\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0mremaining\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m                         \u001b[0;32mraise\u001b[0m \u001b[0mEmpty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnot_empty\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mremaining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m             \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnot_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.8/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    304\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    305\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 306\u001b[0;31m                     \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    307\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    308\u001b[0m                     \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[:-500], train_label[:-500], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.RandomHorizontalFlip(),\n",
    "                        transforms.RandomVerticalFlip(),\n",
    "                        transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=70, shuffle=True, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[-500:], train_label[-500:], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=70, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.007)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)\n",
    "best_acc = 0.0\n",
    "for epoch in range(20):\n",
    "    scheduler.step()\n",
    "    print('Epoch: ', epoch)\n",
    "\n",
    "    train(train_loader, model, criterion, optimizer, epoch)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    \n",
    "    if val_acc.avg.item() > best_acc:\n",
    "        best_acc = round(val_acc.avg.item(), 2)\n",
    "        torch.save(model.state_dict(), f'./model_{best_acc}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "41e86d06-e364-4e3f-8631-c2f3c0722e35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b8b22c4bf5ac4b598ed6bac9b8186218",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/60 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(test_path, [0] * len(test_path), \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_label = pd.DataFrame()\n",
    "val_label['uuid'] = [x.split('/')[-1] for x in test_path]\n",
    "val_label['y_pred'] = predict(test_loader, model, 1).argmax(1)\n",
    "val_label['label'] = val_label['y_pred'].apply(lambda x: label_list[x])\n",
    "val_label[['uuid', 'label']].to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5dc9d50b-9959-4cd7-acfb-91644ccddac6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0            RedFox\n",
       "1            Badger\n",
       "2            RedFox\n",
       "3        LeopardCat\n",
       "4       AmurLeopard\n",
       "           ...     \n",
       "2377           Hare\n",
       "2378    AmurLeopard\n",
       "2379           Hare\n",
       "2380         Badger\n",
       "2381     LeopardCat\n",
       "Name: y_pred, Length: 2382, dtype: object"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f350b9a1-82a7-4051-8492-3dc762d8aaab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
